package io.sqrtqiezi.spark.lagou

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object KNNIris {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local")
      .appName("knn iris data")
      .getOrCreate()

    val K = 9 // 定义 K 的值

    // 读取已经分类的数据
    val irisDF = spark.read
      .option("inferSchema", "true")
      .option("header", "true")
      .csv("lagou-data/iris.csv")

    // 读取待分类的数据
    val iris2DF = spark.read
      .option("inferSchema", "true")
      .option("header", "true")
      .csv("lagou-data/iris2.csv")

    val df1 = iris2DF.crossJoin(irisDF) // 对两个集合取笛卡尔积
      // 计算两个点的欧氏距离
      .withColumn("distance",
        expr("sqrt(pow(SepalLengthCm2 - SepalLengthCm, 2) +" +
          "pow(SepalWidthCm2 - SepalWidthCm, 2) +" +
          "pow(PetalLengthCm2 - PetalLengthCm, 2) +" +
          "pow(PetalWidthCm2 - PetalWidthCm, 2))"))

    val windowSpec = Window.partitionBy("Id2")
      .orderBy("distance")
      .rowsBetween(Window.unboundedPreceding, Window.currentRow)

    // 利用开窗函数，分别对未分类的点和已分类的点的距离进行排序
    val df2 = df1.select(
      col("Id2"),
      col("Species"),
      col("distance"),
      row_number().over(windowSpec).alias("rank")
    )

    val df3 = df2.where(col("rank") <= K) // 保留距离最近的 K 个点
      .groupBy("Id2", "Species") // 分别计算已经分类的个数
      .count()

    // 取最距离最近最多的分类，作为未分类点的分类
    val result = df3.rdd
      .map(row => (row.getInt(0), (row.getString(1), row.getLong(2))))
      .groupByKey
      .mapValues(values => values.toArray.sortWith((left, right) => left._2 > right._2)(0)._1)

    for ((id, species) <- result.collect) {
      println(s"$id -> $species")
    }
  }
}
